"""
Phase 1: Data Assembly — Asset Returns Panel
=============================================
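Downloads BIS real effective exchange rates (REER), BIS residential property
prices, and WDI stock-market indicators (market capitalisation and stocks
traded, each as % of GDP). Merges them with full_panel.csv to create
asset_panel.csv and writes a markdown summary-statistics table.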
"""

import io
import json
import time
import zipfile
from pathlib import Path

import numpy as np
import pandas as pd
import requests

# ── Paths ──────────────────────────────────────────────────────────────
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/asset_returns")
# Sibling project; its followup/data/processed/full_panel.csv is the base panel.
MULTILATERAL_DIR = PROJECT_DIR.parent / "multilateral"
PROCESSED_DIR = PROJECT_DIR.joinpath("data", "processed")
TABLES_DIR = PROJECT_DIR.joinpath("output", "tables")

# Output folders are created at import time so later writes cannot fail on a
# missing directory.
for _out_dir in (PROCESSED_DIR, TABLES_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)
del _out_dir

# ── BIS country code → ISO3 mapping ───────────────────────────────────
# Keys are the two-letter BIS reference-area codes; values are ISO3 codes
# used by the base panel. Codes missing from this dict (or mapped to None)
# become NaN under Series.map and are dropped by the parsers.
# Rows are ordered alphabetically by country name.
BIS_TO_ISO3 = {
    "DZ": "DZA", "AR": "ARG", "AU": "AUS", "AT": "AUT", "BE": "BEL",
    "BR": "BRA", "BG": "BGR", "CA": "CAN", "CL": "CHL", "CN": "CHN",
    "TW": "TWN", "CO": "COL", "HR": "HRV", "CY": "CYP", "CZ": "CZE",
    "DK": "DNK", "EE": "EST", "FI": "FIN", "FR": "FRA", "DE": "DEU",
    "GR": "GRC", "HK": "HKG", "HU": "HUN", "IS": "ISL", "IN": "IND",
    "ID": "IDN", "IE": "IRL", "IL": "ISR", "IT": "ITA", "JP": "JPN",
    "KR": "KOR", "LV": "LVA", "LT": "LTU", "LU": "LUX", "MY": "MYS",
    "MT": "MLT", "MX": "MEX", "MA": "MAR", "NL": "NLD", "NZ": "NZL",
    "NO": "NOR", "PE": "PER", "PH": "PHL", "PL": "POL", "PT": "PRT",
    "RO": "ROU", "RU": "RUS", "SA": "SAU", "RS": "SRB", "SG": "SGP",
    "SK": "SVK", "SI": "SVN", "ZA": "ZAF", "ES": "ESP", "SE": "SWE",
    "CH": "CHE", "TH": "THA", "TR": "TUR", "AE": "ARE", "GB": "GBR",
    "US": "USA",
    # Aggregates (e.g. XM = euro area) have no ISO3 code and are dropped.
    "4T": None, "XM": None,
}


# ── Download helpers ───────────────────────────────────────────────────

def download_bis_zip(url, label):
    """Fetch a BIS bulk-download ZIP and load its first CSV member.

    Parameters
    ----------
    url : str
        Address of the ZIP archive.
    label : str
        Human-readable name, used only in progress messages.

    Returns
    -------
    pandas.DataFrame
        Contents of the first member (in archive order) whose name ends in
        ``.csv``, read with ``low_memory=False``.

    Raises
    ------
    requests.HTTPError
        Via ``raise_for_status`` on a non-success HTTP status.
    ValueError
        If the archive contains no ``.csv`` member.
    """
    print(f"  Downloading {label} from {url} ...")
    response = requests.get(url, timeout=120)
    response.raise_for_status()
    payload = io.BytesIO(response.content)
    with zipfile.ZipFile(payload) as archive:
        first_csv = next(
            (member for member in archive.namelist() if member.endswith(".csv")),
            None,
        )
        if first_csv is None:
            raise ValueError(f"No CSV found in {url}")
        print(f"    Extracting {first_csv} ...")
        with archive.open(first_csv) as handle:
            frame = pd.read_csv(handle, low_memory=False)
    print(f"    Shape: {frame.shape}")
    return frame


def parse_bis_property_prices(raw, area_map=None):
    """Parse a BIS residential property price CSV into annual (iso3, year, hpi).

    Two layouts are accepted:

    * wide ("col") layout: one column per quarter, named like ``2000-Q1``
      (falling back to any column whose name starts with ``19``/``20``);
    * long layout: one row per observation, with a ``*TIME*PERIOD*`` column
      and an ``*OBS*VALUE*`` column.

    Series selection happens before reshaping:

    * ``VALUE`` dimension: only real (CPI-deflated) series are kept (code
      ``R``, else a label containing "real"). Without this, a country's
      nominal and real series would be averaged together.
    * ``UNIT_MEASURE`` dimension: only the index-level series is kept (code
      ``628``, else a label containing "index"). Per the BIS code list, the
      other unit, ``771``, is the year-on-year change in per cent. ``main``
      takes logs and percentage changes of ``hpi``, which need a level.

    A filter is skipped when its column is absent or none of its values
    match, so the parser degrades to the unfiltered data.

    Parameters
    ----------
    raw : pandas.DataFrame
        Table as read from the BIS bulk CSV. It is not modified.
    area_map : dict, optional
        BIS reference-area code -> ISO3 mapping; defaults to ``BIS_TO_ISO3``.
        Areas mapped to ``None`` or missing from the map are dropped.

    Returns
    -------
    pandas.DataFrame
        Columns ``iso3``, ``year``, ``hpi``. ``hpi`` is the mean of the
        period observations within each calendar year. If the long layout
        lacks a time-period or observation-value column, an empty frame with
        these columns is returned.

    Raises
    ------
    ValueError
        If no reference-area column (``REF_AREA`` / ``Reference area``) exists.
    """
    if area_map is None:
        area_map = BIS_TO_ISO3

    def _keep_matching(frame, col, code, label_pattern):
        # Rows whose `col` equals `code`; failing that, rows whose value
        # matches `label_pattern` (case-insensitive, for label columns);
        # failing that, `frame` unchanged.
        if col is None:
            return frame
        values = frame[col].astype(str).str.strip()
        mask = values == code
        if not mask.any():
            mask = values.str.contains(label_pattern, case=False, na=False, regex=True)
        return frame[mask] if mask.any() else frame

    print("  Parsing BIS property prices ...")

    # Quarterly period columns such as 2000-Q1.
    time_cols = [c for c in raw.columns if len(c) == 7 and c[4] == '-' and c[5] == 'Q']
    if not time_cols:
        time_cols = [c for c in raw.columns if c.startswith(("19", "20"))]

    ref_area_col = next((c for c in ('REF_AREA', 'Reference area') if c in raw.columns), None)
    unit_col = next((c for c in ('UNIT_MEASURE', 'Unit of measure') if c in raw.columns), None)
    value_col = next((c for c in ('VALUE', 'Value') if c in raw.columns), None)

    if ref_area_col is None:
        print(f"    Available columns: {list(raw.columns[:20])}")
        raise ValueError("Cannot find reference area column")

    # Real (not nominal) series, index level (not year-on-year change).
    selected = _keep_matching(raw, value_col, 'R', r'\breal\b')
    selected = _keep_matching(selected, unit_col, '628', r'index')

    # Map areas to ISO3 before reshaping so aggregates are never melted.
    # assign() returns a new frame, so `raw` is left untouched.
    selected = selected.assign(
        iso3=selected[ref_area_col].astype(str).str.strip().map(area_map)
    ).dropna(subset=['iso3'])

    if time_cols:
        obs = selected[['iso3'] + time_cols].melt(
            id_vars='iso3', var_name='period', value_name='hpi')
    else:
        # Long format — look for TIME_PERIOD and OBS_VALUE
        print("    Attempting long format parse ...")
        tp_col = next((c for c in raw.columns if 'TIME' in c.upper() and 'PERIOD' in c.upper()), None)
        obs_col = next((c for c in raw.columns if 'OBS' in c.upper() and 'VALUE' in c.upper()), None)
        if tp_col is None or obs_col is None:
            print(f"    Could not find TIME_PERIOD/OBS_VALUE. Cols: {list(raw.columns[:30])}")
            return pd.DataFrame(columns=['iso3', 'year', 'hpi'])
        obs = selected[['iso3', tp_col, obs_col]].rename(
            columns={tp_col: 'period', obs_col: 'hpi'})

    # Numeric values and calendar year (first four characters of the period).
    obs = obs.assign(
        hpi=pd.to_numeric(obs['hpi'], errors='coerce'),
        year=pd.to_numeric(obs['period'].astype(str).str[:4], errors='coerce'),
    ).dropna(subset=['hpi', 'year'])

    # Annual average of the quarterly observations.
    annual = (obs.astype({'year': int})
              .groupby(['iso3', 'year'])['hpi'].mean()
              .reset_index())

    print(f"    HPI: {annual['iso3'].nunique()} countries, {len(annual)} obs")
    return annual


def parse_bis_reer(raw, area_map=None):
    """Parse a BIS effective exchange rate CSV into annual (iso3, year, reer).

    Two layouts are accepted:

    * wide ("col") layout: one column per month, named like ``2020-01``
      (falling back to any column of length >= 7 starting with ``19``/``20``);
    * long layout: one row per observation, with a ``*TIME*PERIOD*`` column
      and an ``*OBS*VALUE*`` column.

    Before reshaping, rows are narrowed to the real, broad-basket series:
    EER type code ``R`` (else a label containing "real") and basket code
    ``B`` (else a label containing "broad"). A filter is skipped when its
    column is absent or nothing matches. The same rule applies to both
    layouts.

    Parameters
    ----------
    raw : pandas.DataFrame
        Table as read from the BIS bulk CSV. It is not modified.
    area_map : dict, optional
        BIS reference-area code -> ISO3 mapping; defaults to ``BIS_TO_ISO3``.
        Areas mapped to ``None`` or missing from the map are dropped.

    Returns
    -------
    pandas.DataFrame
        Columns ``iso3``, ``year``, ``reer``. ``reer`` is the mean of the
        period observations within each calendar year. If the long layout
        lacks a time-period or observation-value column, an empty frame with
        these columns is returned.

    Raises
    ------
    ValueError
        If no reference-area column (``REF_AREA`` / ``Reference area``) exists.
    """
    if area_map is None:
        area_map = BIS_TO_ISO3

    def _keep_matching(frame, col, code, label_pattern):
        # Rows whose `col` equals `code`; failing that, rows whose value
        # matches `label_pattern` (case-insensitive, for label columns);
        # failing that, `frame` unchanged.
        if col is None:
            return frame
        values = frame[col].astype(str).str.strip()
        mask = values == code
        if not mask.any():
            mask = values.str.contains(label_pattern, case=False, na=False, regex=True)
        return frame[mask] if mask.any() else frame

    print("  Parsing BIS REER ...")

    ref_area_col = next((c for c in ('REF_AREA', 'Reference area') if c in raw.columns), None)
    eer_type_col = next((c for c in ('EER_TYPE', 'EER type') if c in raw.columns), None)
    basket_col = next((c for c in ('EER_BASKET', 'EER basket') if c in raw.columns), None)

    if ref_area_col is None:
        print(f"    Available columns: {list(raw.columns[:20])}")
        raise ValueError("Cannot find reference area column")

    # Monthly period columns such as 2020-01.
    time_cols = [c for c in raw.columns if len(c) == 7 and c[4] == '-'
                 and c[5:7].isdigit()]
    if not time_cols:
        time_cols = [c for c in raw.columns if c.startswith(("19", "20"))
                     and len(c) >= 7]

    # Filter FIRST (before any reshape) for the real, broad-basket series.
    selected = _keep_matching(raw, eer_type_col, 'R', r'\breal\b')
    selected = _keep_matching(selected, basket_col, 'B', r'\bbroad\b')
    if time_cols:
        print(f"    After filtering: {len(selected)} rows × {len(time_cols)} time cols")

    # Map areas to ISO3 before reshaping so aggregates are never melted.
    # assign() returns a new frame, so `raw` is left untouched.
    selected = selected.assign(
        iso3=selected[ref_area_col].astype(str).str.strip().map(area_map)
    ).dropna(subset=['iso3'])

    if time_cols:
        # Melting only iso3 + period columns keeps the reshape cheap even
        # with hundreds of monthly columns.
        obs = selected[['iso3'] + time_cols].melt(
            id_vars='iso3', var_name='period', value_name='reer')
    else:
        # Long format
        print("    Attempting long format parse ...")
        tp_col = next((c for c in raw.columns if 'TIME' in c.upper() and 'PERIOD' in c.upper()), None)
        obs_col = next((c for c in raw.columns if 'OBS' in c.upper() and 'VALUE' in c.upper()), None)
        if tp_col is None or obs_col is None:
            print(f"    Could not find TIME_PERIOD/OBS_VALUE. Cols: {list(raw.columns[:30])}")
            return pd.DataFrame(columns=['iso3', 'year', 'reer'])
        obs = selected[['iso3', tp_col, obs_col]].rename(
            columns={tp_col: 'period', obs_col: 'reer'})

    # Numeric values and calendar year (first four characters of the period).
    obs = obs.assign(
        reer=pd.to_numeric(obs['reer'], errors='coerce'),
        year=pd.to_numeric(obs['period'].astype(str).str[:4], errors='coerce'),
    ).dropna(subset=['reer', 'year'])

    annual = (obs.astype({'year': int})
              .groupby(['iso3', 'year'])['reer'].mean()
              .reset_index())

    print(f"    REER: {annual['iso3'].nunique()} countries, {len(annual)} obs")
    return annual


def download_wdi_indicator(indicator, name, max_retries=3):
    """Download one WDI series for the panel countries via the World Bank API.

    Countries are the ISO3 codes in ``full_panel.csv``. They are requested in
    chunks of 15, for 1970-2024. For each request:

    * every result page is fetched (page count from the response metadata),
      so a large response is not silently truncated;
    * network errors, HTTP errors and undecodable JSON are retried up to
      ``max_retries`` times, with a 2-second pause between attempts;
    * if the API answers with anything other than ``[metadata, rows]`` (an
      error payload), the chunk is re-requested one country at a time. This
      way a single rejected code does not silently drop the whole chunk.

    Parameters
    ----------
    indicator : str
        WDI indicator code, e.g. ``"CM.MKT.LCAP.GD.ZS"``.
    name : str
        Column name for the indicator values in the result.
    max_retries : int
        Attempts per HTTP request.

    Returns
    -------
    pandas.DataFrame
        Columns ``iso3``, ``year`` and ``name``, one row per country-year
        with a non-null value (first occurrence kept on duplicates). The
        frame is empty, with those columns, when nothing was retrieved.
    """
    print(f"  Downloading WDI {indicator} ({name}) ...")

    # Country list: every ISO3 code present in the base panel.
    panel = pd.read_csv(MULTILATERAL_DIR / "followup" / "data" / "processed" / "full_panel.csv",
                        usecols=['iso3'])
    countries = sorted(panel['iso3'].dropna().unique())

    def _get_json(url):
        # GET `url` and decode JSON, retrying transient failures; returns
        # None once all `max_retries` attempts have failed.
        for attempt in range(max_retries):
            try:
                resp = requests.get(url, timeout=60)
                resp.raise_for_status()
                return resp.json()
            except (requests.RequestException, ValueError) as e:
                if attempt < max_retries - 1:
                    time.sleep(2)
                else:
                    print(f"    Request failed after {max_retries} attempts: {e}")
        return None

    def _fetch(codes):
        # Fetch every result page for the ';'-joined country `codes`.
        # Returns (rows, rejected); `rejected` is True when the API answered
        # with something other than a [metadata, rows] list.
        rows = []
        page = pages = 1
        while page <= pages:
            url = (f"https://api.worldbank.org/v2/country/{codes}/indicator/{indicator}"
                   f"?format=json&per_page=5000&date=1970:2024&page={page}")
            data = _get_json(url)
            if data is None:
                break  # failure already reported by _get_json
            if not (isinstance(data, list) and len(data) > 1):
                return rows, True
            meta = data[0] if isinstance(data[0], dict) else {}
            pages = int(meta.get('pages') or 1)
            # data[1] may be null when there are no observations.
            for entry in data[1] or []:
                iso3 = entry.get('countryiso3code')
                if not iso3 or entry.get('value') is None:
                    continue
                try:
                    year, value = int(entry['date']), float(entry['value'])
                except (KeyError, TypeError, ValueError):
                    continue
                rows.append({'iso3': iso3, 'year': year, name: value})
            page += 1
        return rows, False

    all_rows = []
    chunk_size = 15
    for start in range(0, len(countries), chunk_size):
        chunk = countries[start:start + chunk_size]
        rows, rejected = _fetch(";".join(chunk))
        if rejected and len(chunk) > 1:
            print(f"    Chunk {chunk[:3]}... rejected by the API; retrying one country at a time")
            rows = []
            for code in chunk:
                country_rows, country_rejected = _fetch(code)
                if country_rejected:
                    print(f"    Skipping {code}: rejected by the API")
                rows.extend(country_rows)
                time.sleep(0.5)
        elif rejected:
            print(f"    Skipping {chunk[0]}: rejected by the API")
        all_rows.extend(rows)

        time.sleep(1.5)  # stay polite to the API between chunks

    df = pd.DataFrame(all_rows, columns=['iso3', 'year', name])
    if len(df) > 0:
        df = df.drop_duplicates(subset=['iso3', 'year'])
        print(f"    {name}: {df['iso3'].nunique()} countries, {len(df)} obs")
    else:
        print(f"    {name}: no data retrieved")
    return df


# ── Main assembly ──────────────────────────────────────────────────────

def _pct_change_consecutive(df, col, group_col='iso3', time_col='year'):
    """Percentage change of ``col`` versus the previous year, per group.

    ``df`` must already be sorted by ``(group_col, time_col)``. The result is
    aligned with ``df.index``. It is NaN for a group's first row and wherever
    the previous row of the same group is not exactly one year earlier, so a
    change across a gap is not reported as a one-year change.
    """
    grouped = df.groupby(group_col)
    change = (df[col] / grouped[col].shift() - 1) * 100
    return change.where(grouped[time_col].diff() == 1)


def _diff_consecutive(df, col, group_col='iso3', time_col='year'):
    """First difference of ``col`` versus the previous year, per group.

    Same sorting requirement, alignment and gap rule as
    ``_pct_change_consecutive``.
    """
    grouped = df.groupby(group_col)
    return grouped[col].diff().where(grouped[time_col].diff() == 1)


def main():
    """Assemble the asset-returns panel, then save it with summary statistics.

    Steps:

    1. Load ``full_panel.csv`` (years <= 2024).
    2. Download and parse BIS property prices and BIS REER.
    3. Download two WDI stock-market indicators.
    4. Left-merge everything onto the base panel on (iso3, year).
    5. Derive logs and year-on-year changes.
    6. Write ``asset_panel.csv`` and a markdown summary table.

    Each download is best-effort: a failure is printed and that source is
    skipped.

    Year-on-year changes (``d_rhpi``, ``d_reer``, ``d_stock_market_cap``)
    are computed only between consecutive years of the same country. Across
    a gap in the data they are NaN.

    Returns
    -------
    pandas.DataFrame
        The merged panel written to ``asset_panel.csv``.
    """
    print("=" * 70)
    print("PHASE 1: Data Assembly — Asset Returns Panel")
    print("=" * 70)

    # 1. Load base panel
    print("\n1. Loading full_panel.csv ...")
    panel = pd.read_csv(MULTILATERAL_DIR / "followup" / "data" / "processed" / "full_panel.csv")
    panel = panel[panel['year'] <= 2024].copy()
    print(f"   Base panel: {panel['iso3'].nunique()} countries, {len(panel)} obs, "
          f"{panel['year'].min()}-{panel['year'].max()}")

    # 2. Download BIS house prices
    print("\n2. BIS Residential Property Prices ...")
    try:
        raw_hpi = download_bis_zip(
            "https://data.bis.org/static/bulk/WS_SPP_csv_col.zip",
            "BIS Property Prices"
        )
        hpi = parse_bis_property_prices(raw_hpi)

        # Construct variables (clip floors values at 1 so the log stays >= 0)
        if len(hpi) > 0:
            hpi = hpi.sort_values(['iso3', 'year'])
            hpi['log_rhpi'] = np.log(hpi['hpi'].clip(lower=1))
            hpi['d_rhpi'] = _pct_change_consecutive(hpi, 'hpi')
            hpi = hpi[['iso3', 'year', 'hpi', 'log_rhpi', 'd_rhpi']]
    except Exception as e:
        print(f"   BIS HPI download failed: {e}")
        hpi = pd.DataFrame(columns=['iso3', 'year', 'hpi', 'log_rhpi', 'd_rhpi'])

    # 3. Download BIS REER
    print("\n3. BIS Real Effective Exchange Rates ...")
    try:
        raw_reer = download_bis_zip(
            "https://data.bis.org/static/bulk/WS_EER_csv_col.zip",
            "BIS REER"
        )
        reer = parse_bis_reer(raw_reer)

        if len(reer) > 0:
            reer = reer.sort_values(['iso3', 'year'])
            reer['log_reer'] = np.log(reer['reer'].clip(lower=1))
            reer['d_reer'] = _pct_change_consecutive(reer, 'reer')
            reer = reer[['iso3', 'year', 'reer', 'log_reer', 'd_reer']]
    except Exception as e:
        print(f"   BIS REER download failed: {e}")
        reer = pd.DataFrame(columns=['iso3', 'year', 'reer', 'log_reer', 'd_reer'])

    # 4. Download WDI stock market data
    print("\n4. WDI Stock Market Indicators ...")
    try:
        stock_cap = download_wdi_indicator("CM.MKT.LCAP.GD.ZS", "stock_market_cap_gdp")
    except Exception as e:
        print(f"   Stock market cap download failed: {e}")
        stock_cap = pd.DataFrame(columns=['iso3', 'year', 'stock_market_cap_gdp'])

    try:
        stock_traded = download_wdi_indicator("CM.MKT.TRAD.GD.ZS", "stocks_traded_gdp")
    except Exception as e:
        print(f"   Stocks traded download failed: {e}")
        stock_traded = pd.DataFrame(columns=['iso3', 'year', 'stocks_traded_gdp'])

    # 5. Merge everything (left joins keep the base panel's rows)
    print("\n5. Merging datasets ...")
    merged = panel.copy()

    if len(hpi) > 0:
        merged = merged.merge(hpi, on=['iso3', 'year'], how='left')
        print(f"   After HPI merge: {merged['hpi'].notna().sum()} non-null HPI obs")

    if len(reer) > 0:
        merged = merged.merge(reer, on=['iso3', 'year'], how='left')
        print(f"   After REER merge: {merged['reer'].notna().sum()} non-null REER obs")

    if len(stock_cap) > 0:
        merged = merged.merge(stock_cap, on=['iso3', 'year'], how='left')
        print(f"   After stock cap merge: {merged['stock_market_cap_gdp'].notna().sum()} non-null")

    if len(stock_traded) > 0:
        merged = merged.merge(stock_traded, on=['iso3', 'year'], how='left')
        print(f"   After stocks traded merge: {merged['stocks_traded_gdp'].notna().sum()} non-null")

    # Construct change in stock market cap (consecutive years only)
    if 'stock_market_cap_gdp' in merged.columns:
        merged = merged.sort_values(['iso3', 'year'])
        merged['d_stock_market_cap'] = _diff_consecutive(merged, 'stock_market_cap_gdp')

    # 6. Save
    out_path = PROCESSED_DIR / "asset_panel.csv"
    merged.to_csv(out_path, index=False)
    print(f"\n   Saved: {out_path}")
    print(f"   Shape: {merged.shape}")

    # 7. Summary statistics (printed as step "6.")
    print("\n6. Summary statistics ...")
    asset_vars = ['real_bond_10y', 'real_short_3m', 'term_spread', 'govt_bond_10y',
                  'reer', 'log_reer', 'd_reer',
                  'hpi', 'log_rhpi', 'd_rhpi',
                  'stock_market_cap_gdp', 'stocks_traded_gdp', 'd_stock_market_cap',
                  'port_eq_assets_gdp', 'carry_vs_usa', 'carry_vs_jpn',
                  'fx_hedged_vs_usa', 'fx_hedged_vs_jpn']
    asset_vars = [v for v in asset_vars if v in merged.columns]

    stats_rows = []
    for var in asset_vars:
        s = merged[var].dropna()
        if len(s) == 0:
            continue
        n_countries = merged.loc[merged[var].notna(), 'iso3'].nunique()
        stats_rows.append({
            'Variable': var,
            'N': len(s),
            'Countries': n_countries,
            'Mean': s.mean(),
            'Std': s.std(),
            'Min': s.min(),
            'p25': s.quantile(0.25),
            'Median': s.median(),
            'p75': s.quantile(0.75),
            'Max': s.max(),
        })

    stats_df = pd.DataFrame(stats_rows)

    # Save as markdown table
    md_lines = ["# Summary Statistics: Asset Variables\n"]
    md_lines.append("| Variable | N | Countries | Mean | Std | Min | Median | Max |")
    md_lines.append("|---|---|---|---|---|---|---|---|")
    for _, row in stats_df.iterrows():
        md_lines.append(
            f"| {row['Variable']} | {row['N']:.0f} | {row['Countries']:.0f} "
            f"| {row['Mean']:.2f} | {row['Std']:.2f} | {row['Min']:.2f} "
            f"| {row['Median']:.2f} | {row['Max']:.2f} |"
        )
    md_lines.append(f"\n*Panel: {merged['iso3'].nunique()} countries, "
                    f"{merged['year'].min()}-{merged['year'].max()}*")

    stats_path = TABLES_DIR / "summary_statistics.md"
    stats_path.write_text('\n'.join(md_lines))
    print(f"   Saved: {stats_path}")

    print("\n" + "=" * 70)
    print("Phase 1 complete.")
    print("=" * 70)

    return merged


if __name__ == "__main__":
    # Script entry point. main() also returns the merged panel (useful when
    # called from other code); when run as a script the return value is unused.
    main()
